import numpy as np
import torch
import os
import shutil
from ASVBaselineTool.feature_tools import *
import soundfile as sf
import torchaudio
import torch.nn as nn
from ASVBaselineTool.feature_tools import *

#FGSM攻击框架
class FRAMEFGSM:
    """One-step targeted FGSM attack framework for a two-class (fake/genuine) audio model.

    Reads a label file whose non-blank lines are ``<filename> ... <label>``,
    with the last whitespace-separated token being ``genuine`` or ``fake``.
    ``Attack`` then perturbs each listed waveform toward the opposite class and
    writes the waveforms whose prediction flips into ``AdvsetPath``.

    Args:
        model: ``torch.nn.Module`` mapping the selected feature to a score
            tensor assumed to be ``(1, 2)`` (index 0 = fake, index 1 = genuine,
            matching the one-hot labels built here).
        labelPath: path of the label file.
        datasetPath: directory holding the original audio files.
        AdvsetPath: directory where successful adversarial examples are saved.

    Raises:
        ValueError: if a label is neither ``genuine`` nor ``fake``.
    """

    def __init__(self, model, labelPath, datasetPath, AdvsetPath):
        # Parse "<fname> ... <label>" lines; split() tolerates repeated or
        # trailing whitespace, and blank lines are ignored.
        with open(labelPath, "r") as fh:
            rows = [line.split() for line in fh if line.strip()]
        self.AFnames = [row[0] for row in rows]
        # One-hot label per file name: fake -> [1, 0], genuine -> [0, 1].
        self.LabelMeta = {}
        for fname, row in zip(self.AFnames, rows):
            label = row[-1]
            if label == "genuine":
                self.LabelMeta[fname] = torch.Tensor([0.0, 1.0])
            elif label == "fake":
                self.LabelMeta[fname] = torch.Tensor([1.0, 0.0])
            else:
                raise ValueError(
                    f"unknown label {label!r} for {fname!r} in {labelPath}; "
                    "expected 'genuine' or 'fake'"
                )
        # Pick the device once; fall back to CPU when CUDA is unavailable.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.model.to(self.device)
        self.model.eval()

        self.datasetPath = datasetPath
        # Directory where successful adversarial examples are written.
        self.AdvSavePath = AdvsetPath

    def fgsm_attack(self, srcwave, epsilon, data_grad):
        """Return ``srcwave + epsilon * sign(data_grad)`` clamped to ``[-1, 1]``.

        A positive ``epsilon`` ascends the loss whose gradient is
        ``data_grad``; a negative ``epsilon`` descends it.
        """
        sign_data_grad = data_grad.sign()
        perturbed_wav = srcwave + epsilon * sign_data_grad
        # Keep every sample inside the [-1, 1] amplitude range.
        return torch.clamp(perturbed_wav, -1, 1)

    def feature_select(self, feature, wav: torch.Tensor):
        """Compute the model input for ``wav`` using the named front-end.

        ``"raw"`` returns the waveform reshaped to ``(1, num_samples)``. The
        other names dispatch to ``get_*`` extractors, presumably provided by
        the ``ASVBaselineTool.feature_tools`` star import.

        Raises:
            ValueError: if ``feature`` is not one of the supported names.
        """
        if feature == "spec_2048":
            return get_SPEC_2048(wav)
        if feature == "spec_1024":
            return get_SPEC_1024(wav)
        if feature == "mfcc_40":
            return get_MFCC_40(wav)
        if feature == "mfcc_80":
            return get_MFCC_80(wav)
        if feature == "lfcc_70":
            return get_LFCC_70(wav)
        if feature == "lfcc_60":
            return get_LFCC_60(wav)
        if feature == "raw":
            return wav.reshape(1, -1)
        raise ValueError(f"unsupported feature: {feature!r}")

    def model_bn_to_eval(self):
        """Put the model in train mode while keeping its BatchNorm layers in eval mode.

        Hard-coded to an architecture exposing ``first_bn``, ``block0``..``block5``
        and ``bn_before_gru`` (presumably a RawNet2-style model; verify).
        NOTE(review): train mode is presumably required so gradients can be
        back-propagated through the GRU (cuDNN RNN backward needs train mode).
        Any dropout stays active, and the model is not switched back to eval.
        """
        self.model.train()
        self.model.first_bn.eval()
        self.model.block0[0].bn2.eval()
        self.model.block1[0].bn1.eval()
        self.model.block1[0].bn2.eval()
        self.model.block2[0].bn1.eval()
        self.model.block2[0].bn2.eval()
        self.model.block3[0].bn1.eval()
        self.model.block3[0].bn2.eval()
        self.model.block4[0].bn1.eval()
        self.model.block4[0].bn2.eval()
        self.model.block5[0].bn1.eval()
        self.model.block5[0].bn2.eval()
        self.model.bn_before_gru.eval()

    def _forward(self, feature, wav):
        """Run the model on the ``feature`` representation of ``wav``."""
        return self.model(self.feature_select(feature, wav).to(self.device))

    def _targeted_step(self, wav, logits, target, eps):
        """Take one FGSM step of size ``eps`` moving the prediction toward ``target``.

        Args:
            wav: waveform tensor with ``requires_grad=True`` from which
                ``logits`` were computed.
            logits: model output for ``wav`` (graph still attached).
            target: target class distribution, same shape as ``logits``.
            eps: step size.

        Returns:
            A detached perturbed waveform clamped to ``[-1, 1]``.

        Raises:
            RuntimeError: if no gradient reached ``wav`` (non-differentiable
                feature extraction).
        """
        loss = nn.CrossEntropyLoss()(logits, target)
        self.model.zero_grad()
        loss.backward()
        if wav.grad is None:
            raise RuntimeError(
                "no gradient reached the waveform; the feature extractor "
                "must be differentiable in torch"
            )
        # Descend the loss w.r.t. the *target* class (hence -eps). Ascending
        # it would push the sample away from the target and back toward its
        # true class, i.e. make the attack counter-productive.
        return self.fgsm_attack(wav.detach(), -eps, wav.grad.detach())

    def Attack(self, feature: str, eps=0.2):
        """Attack every labelled file and save the successful adversarial examples.

        Files already present in ``AdvSavePath`` are skipped. Files the model
        already misclassifies are counted but not attacked. The rest get one
        targeted FGSM step of size ``eps`` toward the opposite class. A
        perturbed waveform is written to ``AdvSavePath`` (same file name,
        original sample rate) only if its prediction no longer matches the
        true label.

        The printed success rate is successful attacks divided by the files
        processed in this call, including initially misclassified ones. It is
        0.0 when nothing was processed.
        """
        already_done = set(os.listdir(self.AdvSavePath))
        print("attack eps:" + str(eps))
        print("总样本：" + str(len(self.AFnames)))
        num_f = 0   # files processed in this call
        num_as = 0  # successful attacks
        for item in self.AFnames:
            if item in already_done:
                print("已经攻击成功")
                continue
            num_f += 1
            src_wav, sr = sf.read(os.path.join(self.datasetPath, item))
            adv_wav = torch.Tensor(src_wav)
            adv_wav.requires_grad = True
            if feature == "raw":
                # Keep BN layers frozen while gradients flow through the model.
                self.model_bn_to_eval()
            logits = self._forward(feature, adv_wav)
            init_pred = torch.max(logits, dim=1)[1].item()
            y = self.LabelMeta[item]
            init_y = y.max(0, keepdim=True)[1].item()
            if init_pred != init_y:
                # Already misclassified: nothing to attack.
                print("原始分类错误 pass")
                continue
            # The target is the opposite class of the one-hot true label.
            targety = (-1 * y + 1).to(self.device).reshape(1, 2)
            perturbed_data = self._targeted_step(adv_wav, logits, targety, eps)
            # Re-extract features from the perturbed waveform and re-classify;
            # no autograd graph is needed for this check.
            with torch.no_grad():
                new_pred = self._forward(feature, perturbed_data).max(1)[1].item()
            if new_pred == init_y:
                print("攻击失败 pass")
                continue
            sp = os.path.join(self.AdvSavePath, item)
            print("攻击成功并保存值路径:" + sp)
            num_as += 1
            sf.write(sp, perturbed_data.cpu().numpy(), sr)
        print("总样本：" + str(num_f))
        print("攻击成功样本数量:" + str(num_as))
        print("攻击成功率:" + str(num_as / num_f if num_f else 0.0))




















